c81bd2
@@ -111,7 +111,7 @@
public Object process(Node nd, Stack<Node> stack,
       // will result into a vertex with multiple FS or RS operators.
       if (context.childToWorkMap.containsKey(operator)) {
         // if we've seen both root and child, we can bail.
-        
+
         // clear out the mapjoin set. we don't need it anymore.
         context.currentMapJoinOperators.clear();
 
@@ -349,17 +349,20 @@
public Object process(Node nd, Stack<Node> stack,
         } else if (followingWork instanceof UnionWork) {
           // this can only be possible if there is merge work followed by the union
           UnionWork unionWork = (UnionWork) followingWork;
-          int index = getMergeIndex(tezWork, unionWork, rs);
-          // guaranteed to be instance of MergeJoinWork if index is valid
-          BaseWork baseWork = tezWork.getChildren(unionWork).get(index);
-          if (baseWork instanceof MergeJoinWork) {
-            MergeJoinWork mergeJoinWork = (MergeJoinWork) baseWork;
-            // disconnect the connection to union work and connect to merge work
-            followingWork = mergeJoinWork;
-            rWork = (ReduceWork) mergeJoinWork.getMainWork();
+          int index = getFollowingWorkIndex(tezWork, unionWork, rs);
+          if (index != -1) {
+            BaseWork baseWork = tezWork.getChildren(unionWork).get(index);
+            if (baseWork instanceof MergeJoinWork) {
+              MergeJoinWork mergeJoinWork = (MergeJoinWork) baseWork;
+              // disconnect the connection to union work and connect to merge work
+              followingWork = mergeJoinWork;
+              rWork = (ReduceWork) mergeJoinWork.getMainWork();
+            } else {
+              rWork = (ReduceWork) baseWork;
+            }
           } else {
-            throw new SemanticException("Unknown work type found: "
-                + baseWork.getClass().getCanonicalName());
+            throw new SemanticException("Following work not found for the reduce sink: "
+                + rs.getName());
           }
         } else {
           rWork = (ReduceWork) followingWork;
@@ -403,19 +406,13 @@
public Object process(Node nd, Stack<Node> stack,
     return null;
   }
 
-  private int getMergeIndex(TezWork tezWork, UnionWork unionWork, ReduceSinkOperator rs) {
+  private int getFollowingWorkIndex(TezWork tezWork, UnionWork unionWork, ReduceSinkOperator rs) {
     int index = 0;
     for (BaseWork baseWork : tezWork.getChildren(unionWork)) {
-      if (baseWork instanceof MergeJoinWork) {
-        MergeJoinWork mergeJoinWork = (MergeJoinWork) baseWork;
-        int tag = mergeJoinWork.getMergeJoinOperator().getTagForOperator(rs);
-        if (tag != -1) {
-          return index;
-        } else {
-          index++;
-        }
-      } else {
+      if (tezWork.getEdgeProperty(unionWork, baseWork).equals(TezEdgeProperty.EdgeType.CONTAINS)) {
         index++;
+      } else {
+        return index;
       }
     }
 
